function [Phi_Nxz,varargout] = MakeSparseTransMatrix_Nxz(parms,gesol)
%MakeSparseTransMatrix_Nxz  Sparse Markov transition matrix on the (N,x,z) grid.
%
%   Phi_Nxz = MakeSparseTransMatrix_Nxz(parms,gesol) returns a sparse
%   (NN*Nx*Nz)-by-(NN*Nx*Nz) matrix. Row i / column j are linear indices of
%   the (N,x,z) grid with N varying fastest, then x, then z. Entry (i,j) is
%   the probability of moving from node i to node j; every row sums to one
%   (checked below to 1e-10).
%
%   [Phi_Nxz,prob_Nxz] = MakeSparseTransMatrix_Nxz(parms,gesol) additionally
%   returns an NN-by-Nx-by-Nz array obtained from the leading eigenvector of
%   Phi_Nxz', normalised to sum to one (the stationary distribution).
%
%   Laws of motion used (successors are clipped to the grid range):
%       N' = (1 - s(z)).*N + f.*(1 - N)
%       x' = (1-rho_x)*x_bar + rho_x*x + sigma_x*(z' - E[z'|z])/std
%   where std is the conditional standard deviation of z' given z when
%   parms.OPTION_x_standardize_shock is true, and 1 otherwise. Off-grid
%   successors are split linearly between the two bracketing nodes in each of
%   the N and x dimensions; z moves according to parms.StateTransitionProbs.
%
%   Fields read from parms: grid_N (evenly spaced), grid_x (strictly
%   increasing, need not be evenly spaced), grid_z, NN, Nx, Nz, s, sigma_x,
%   rho_x, x_bar, StateTransitionProbs, OPTION_x_standardize_shock.
%   Field read from gesol: f.
%   NOTE(review): the code appears to expect parms.s to hold one value per z
%   state, parms.sigma_x to be NN-by-Nz and gesol.f to be NN-by-Nx-by-Nz --
%   confirm against the caller.
%
%   Errors are raised when more than two outputs are requested, when the
%   grids violate the spacing/monotonicity requirements, when the assembled
%   matrix has a negative entry or a row that does not sum to one, or when
%   eigs reports non-convergence.

    % Reject an unsupported call before doing any expensive work.
    if nargout>2
        error('At most two outputs.');
    end

    dN = parms.grid_N(2) - parms.grid_N(1);
    dx = diff(parms.grid_x);
    if any(abs(diff(parms.grid_N) - dN)>1e-10)
        error('grids for N should be evenly spaced');
    end
    if any(dx<=0)
        error('grid for x should be increasing');
    end

    NN = parms.NN; Nx = parms.Nx; Nz = parms.Nz;
    N_lb = parms.grid_N(1);
    nNx = NN*Nx;    % number of (N,x) nodes per z state

    % Replicate the state variables onto the full NN-by-Nx-by-Nz grid.
    N_Nxz = repmat(parms.grid_N(:),[1 Nx Nz]);
    x_Nxz = permute(repmat(parms.grid_x(:),[1 NN Nz]),[2 1 3]);
    s_Nxz = permute(repmat(parms.s(:),[1 NN Nx]),[2 3 1]);
    sigma_x_Nxz = permute(repmat(parms.sigma_x,[1 1 Nx]),[1 3 2]);

    % Next-period N, clipped to the grid range.
    Nnext_Nxz = (1 - s_Nxz).*N_Nxz + gesol.f.*(1 - N_Nxz);
    Nnext_Nxz(Nnext_Nxz<parms.grid_N(1)) = parms.grid_N(1);
    Nnext_Nxz(Nnext_Nxz>parms.grid_N(end)) = parms.grid_N(end);

    % Next-period x as a function of (N,x,z,z'), clipped to the grid range.
    x_next_Nxzz = zeros(NN,Nx,Nz,Nz);
    for iz_now = 1:Nz
        expected_znext = parms.StateTransitionProbs(iz_now,:)*parms.grid_z(:);
        if parms.OPTION_x_standardize_shock
            std_znext = sqrt(parms.StateTransitionProbs(iz_now,:)*(parms.grid_z(:).^2) - expected_znext^2);
        else
            std_znext = 1;
        end
        for iz_next = 1:Nz
            x_next_Nxzz(:,:,iz_now,iz_next) = (1-parms.rho_x)*parms.x_bar + parms.rho_x*x_Nxz(:,:,iz_now) ...
                + sigma_x_Nxz(:,:,iz_now)*(parms.grid_z(iz_next) - expected_znext)/std_znext;
        end
    end
    x_next_Nxzz(x_next_Nxzz<parms.grid_x(1)) = parms.grid_x(1);
    x_next_Nxzz(x_next_Nxzz>parms.grid_x(end)) = parms.grid_x(end);

    % Bracketing nodes and linear-interpolation weights for N'.
    Nnext_Nxz_idx1 = min(floor((Nnext_Nxz - N_lb)/dN) + 1, NN-1);
    Nnext_Nxz_idx2 = Nnext_Nxz_idx1 + 1;
    wNnext_Nxz_idx1 = (parms.grid_N(Nnext_Nxz_idx2) - Nnext_Nxz)/dN;
    % The bracket is located with the nominal spacing dN while the weight uses
    % the actual node value, and the grid is only required to be even up to
    % 1e-10. A successor sitting on (or within rounding of) a node can thus
    % get a weight marginally outside [0,1], which would trip the
    % negative-entry check below. Snap such weights back to the boundary.
    % Logical indexing is used so that NaN weights are left untouched.
    wNnext_Nxz_idx1(wNnext_Nxz_idx1<0) = 0;
    wNnext_Nxz_idx1(wNnext_Nxz_idx1>1) = 1;
    wNnext_Nxz_idx2 = 1 - wNnext_Nxz_idx1;

    % Bracketing nodes and weights for x' (the grid may be unevenly spaced,
    % so the lower node is found by search and the local spacing is used).
    x_next_Nxzz_idx1 = arrayfun(@(X)find(parms.grid_x(1:end-1)<=X,1,'last'),x_next_Nxzz);
    x_next_Nxzz_idx2 = x_next_Nxzz_idx1 + 1;
    wx_next_Nxzz_idx1 = (parms.grid_x(x_next_Nxzz_idx2) - x_next_Nxzz)./dx(x_next_Nxzz_idx1);
    wx_next_Nxzz_idx2 = 1 - wx_next_Nxzz_idx1;

    % Flatten the (N,x) dimensions into one.
    Nnext_Nxz_idx1 = reshape(Nnext_Nxz_idx1,[nNx, Nz]);
    Nnext_Nxz_idx2 = reshape(Nnext_Nxz_idx2,[nNx, Nz]);
    wNnext_Nxz_idx1 = reshape(wNnext_Nxz_idx1,[nNx, Nz]);
    wNnext_Nxz_idx2 = reshape(wNnext_Nxz_idx2,[nNx, Nz]);
    x_next_Nxzz_idx1 = reshape(x_next_Nxzz_idx1,[nNx, Nz, Nz]);
    x_next_Nxzz_idx2 = reshape(x_next_Nxzz_idx2,[nNx, Nz, Nz]);
    wx_next_Nxzz_idx1 = reshape(wx_next_Nxzz_idx1,[nNx, Nz, Nz]);
    wx_next_Nxzz_idx2 = reshape(wx_next_Nxzz_idx2,[nNx, Nz, Nz]);

    % Linear (N,x) index and joint weight of each of the four corner nodes
    % surrounding (N',x'); N' does not depend on z', so it is replicated
    % along the third dimension.
    Nxnext_idx11 = sub2ind_2d_fast([NN Nx], repmat(Nnext_Nxz_idx1,[1 1 Nz]), x_next_Nxzz_idx1);
    Nxnext_idx12 = sub2ind_2d_fast([NN Nx], repmat(Nnext_Nxz_idx1,[1 1 Nz]), x_next_Nxzz_idx2);
    Nxnext_idx21 = sub2ind_2d_fast([NN Nx], repmat(Nnext_Nxz_idx2,[1 1 Nz]), x_next_Nxzz_idx1);
    Nxnext_idx22 = sub2ind_2d_fast([NN Nx], repmat(Nnext_Nxz_idx2,[1 1 Nz]), x_next_Nxzz_idx2);
    wNxnext_idx11 = repmat(wNnext_Nxz_idx1,[1 1 Nz]).*wx_next_Nxzz_idx1;
    wNxnext_idx12 = repmat(wNnext_Nxz_idx1,[1 1 Nz]).*wx_next_Nxzz_idx2;
    wNxnext_idx21 = repmat(wNnext_Nxz_idx2,[1 1 Nz]).*wx_next_Nxzz_idx1;
    wNxnext_idx22 = repmat(wNnext_Nxz_idx2,[1 1 Nz]).*wx_next_Nxzz_idx2;

    % Linear index of every current (N,x) node, in column-major order.
    IDX_Nxnow = (1:nNx)';

    % Assemble the (row, column, value) triplets. Each (z,z') pair adds four
    % entries per (N,x) node, so the total size is known up front; filling
    % preallocated vectors avoids re-copying them on every iteration.
    nPerPair = 4*nNx;
    Phi_idx_from = zeros(nPerPair*Nz*Nz,1);
    Phi_idx_to = zeros(nPerPair*Nz*Nz,1);
    Phi_val = zeros(nPerPair*Nz*Nz,1);
    offset = 0;
    for iznow = 1:Nz
        idx_from = sub2ind_2d_fast([nNx, Nz],IDX_Nxnow, iznow);
        for iznext = 1:Nz
            prob_z = parms.StateTransitionProbs(iznow,iznext);
            slots = offset + (1:nPerPair);
            Phi_idx_from(slots) = [idx_from;idx_from;idx_from;idx_from];
            Phi_idx_to(slots) = [ ...
                sub2ind_2d_fast([nNx, Nz],Nxnext_idx11(:,iznow,iznext),iznext); ...
                sub2ind_2d_fast([nNx, Nz],Nxnext_idx12(:,iznow,iznext),iznext); ...
                sub2ind_2d_fast([nNx, Nz],Nxnext_idx21(:,iznow,iznext),iznext); ...
                sub2ind_2d_fast([nNx, Nz],Nxnext_idx22(:,iznow,iznext),iznext)];
            Phi_val(slots) = [ ...
                prob_z*wNxnext_idx11(:,iznow,iznext); ...
                prob_z*wNxnext_idx12(:,iznow,iznext); ...
                prob_z*wNxnext_idx21(:,iznow,iznext); ...
                prob_z*wNxnext_idx22(:,iznow,iznext)];
            offset = offset + nPerPair;
        end
    end

    % sparse() adds up triplets that share the same (row, column) pair.
    Phi_Nxz = sparse(Phi_idx_from,Phi_idx_to,Phi_val,nNx*Nz,nNx*Nz);

    if any(Phi_Nxz(:)<0) || any(abs(1 - sum(Phi_Nxz,2))>1e-10)
        error('invalid transition matrix')
    end

    if nargout==2
        % Stationary distribution: eigenvector of Phi' for its
        % largest-magnitude eigenvalue, rescaled to sum to one (the rescaling
        % also removes the arbitrary sign returned by eigs).
        [V,~,FLAG] = eigs(Phi_Nxz',1);
        if FLAG~=0; error('Failed to find stationary distribution'); end
        prob_Nxz = reshape(V(:,1)./sum(V(:,1)),[NN Nx Nz]);
        varargout(1) = {prob_Nxz};
    end

end

function ind = sub2ind_2d_fast(sz, i1, i2)
%sub2ind_2d_fast  Column-major linear index from 2-D subscripts.
%
%   ind = sub2ind_2d_fast(sz, i1, i2) maps the row subscripts i1 and column
%   subscripts i2 of an array whose first dimension has length sz(1) to
%   linear indices. Only sz(1) is used, and the subscripts are not
%   range-checked. i1 and i2 are combined element-wise, so they must be the
%   same size or one of them a scalar.

    nRows = sz(1);
    colOffset = (i2 - 1)*nRows;   % elements in the columns before i2
    ind = colOffset + i1;

end